{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# APPLYING BERT CLASSIFIERS TO HATECHECK"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialise relevant packages\n",
    "\n",
    "# Basics\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "# Preprocessing\n",
    "import torch\n",
    "\n",
    "# Modelling\n",
    "from transformers import BertTokenizerFast, BertForSequenceClassification, Trainer, TrainingArguments\n",
    "\n",
    "# Evaluation\n",
    "from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Test Suite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>functionality</th>\n",
       "      <th>case_id</th>\n",
       "      <th>test_case</th>\n",
       "      <th>label_gold</th>\n",
       "      <th>target_ident</th>\n",
       "      <th>direction</th>\n",
       "      <th>focus_words</th>\n",
       "      <th>focus_lemma</th>\n",
       "      <th>ref_case_id</th>\n",
       "      <th>ref_templ_id</th>\n",
       "      <th>templ_id</th>\n",
       "      <th>case_templ</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>derog_neg_emote_h</td>\n",
       "      <td>1</td>\n",
       "      <td>I hate women.</td>\n",
       "      <td>hateful</td>\n",
       "      <td>women</td>\n",
       "      <td>general</td>\n",
       "      <td>hate</td>\n",
       "      <td>hate</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1</td>\n",
       "      <td>I hate [IDENTITY_P].</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>derog_neg_emote_h</td>\n",
       "      <td>2</td>\n",
       "      <td>I hate trans people.</td>\n",
       "      <td>hateful</td>\n",
       "      <td>trans people</td>\n",
       "      <td>general</td>\n",
       "      <td>hate</td>\n",
       "      <td>hate</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1</td>\n",
       "      <td>I hate [IDENTITY_P].</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>derog_neg_emote_h</td>\n",
       "      <td>3</td>\n",
       "      <td>I hate gay people.</td>\n",
       "      <td>hateful</td>\n",
       "      <td>gay people</td>\n",
       "      <td>general</td>\n",
       "      <td>hate</td>\n",
       "      <td>hate</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1</td>\n",
       "      <td>I hate [IDENTITY_P].</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>derog_neg_emote_h</td>\n",
       "      <td>4</td>\n",
       "      <td>I hate black people.</td>\n",
       "      <td>hateful</td>\n",
       "      <td>black people</td>\n",
       "      <td>general</td>\n",
       "      <td>hate</td>\n",
       "      <td>hate</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1</td>\n",
       "      <td>I hate [IDENTITY_P].</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>derog_neg_emote_h</td>\n",
       "      <td>5</td>\n",
       "      <td>I hate disabled people.</td>\n",
       "      <td>hateful</td>\n",
       "      <td>disabled people</td>\n",
       "      <td>general</td>\n",
       "      <td>hate</td>\n",
       "      <td>hate</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1</td>\n",
       "      <td>I hate [IDENTITY_P].</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3723</th>\n",
       "      <td>spell_leet_h</td>\n",
       "      <td>3897</td>\n",
       "      <td>Gay people are all complete 4ssholes.</td>\n",
       "      <td>hateful</td>\n",
       "      <td>gay people</td>\n",
       "      <td>general</td>\n",
       "      <td>4ssholes</td>\n",
       "      <td>asshole</td>\n",
       "      <td>1232.0</td>\n",
       "      <td>256.0</td>\n",
       "      <td>866</td>\n",
       "      <td>[IDENTITY_P] are all complete 4ssholes.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3724</th>\n",
       "      <td>spell_leet_h</td>\n",
       "      <td>3898</td>\n",
       "      <td>Black people are all complete 4ssholes.</td>\n",
       "      <td>hateful</td>\n",
       "      <td>black people</td>\n",
       "      <td>general</td>\n",
       "      <td>4ssholes</td>\n",
       "      <td>asshole</td>\n",
       "      <td>1233.0</td>\n",
       "      <td>256.0</td>\n",
       "      <td>866</td>\n",
       "      <td>[IDENTITY_P] are all complete 4ssholes.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3725</th>\n",
       "      <td>spell_leet_h</td>\n",
       "      <td>3899</td>\n",
       "      <td>Disabled people are all complete 4ssholes.</td>\n",
       "      <td>hateful</td>\n",
       "      <td>disabled people</td>\n",
       "      <td>general</td>\n",
       "      <td>4ssholes</td>\n",
       "      <td>asshole</td>\n",
       "      <td>1234.0</td>\n",
       "      <td>256.0</td>\n",
       "      <td>866</td>\n",
       "      <td>[IDENTITY_P] are all complete 4ssholes.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3726</th>\n",
       "      <td>spell_leet_h</td>\n",
       "      <td>3900</td>\n",
       "      <td>Muslims are all complete 4ssholes.</td>\n",
       "      <td>hateful</td>\n",
       "      <td>Muslims</td>\n",
       "      <td>general</td>\n",
       "      <td>4ssholes</td>\n",
       "      <td>asshole</td>\n",
       "      <td>1235.0</td>\n",
       "      <td>256.0</td>\n",
       "      <td>866</td>\n",
       "      <td>[IDENTITY_P] are all complete 4ssholes.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3727</th>\n",
       "      <td>spell_leet_h</td>\n",
       "      <td>3901</td>\n",
       "      <td>Immigrants are all complete 4ssholes.</td>\n",
       "      <td>hateful</td>\n",
       "      <td>immigrants</td>\n",
       "      <td>general</td>\n",
       "      <td>4ssholes</td>\n",
       "      <td>asshole</td>\n",
       "      <td>1236.0</td>\n",
       "      <td>256.0</td>\n",
       "      <td>866</td>\n",
       "      <td>[IDENTITY_P] are all complete 4ssholes.</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>3728 rows × 12 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "          functionality  case_id                                    test_case  \\\n",
       "0     derog_neg_emote_h        1                               I hate women.    \n",
       "1     derog_neg_emote_h        2                        I hate trans people.    \n",
       "2     derog_neg_emote_h        3                          I hate gay people.    \n",
       "3     derog_neg_emote_h        4                        I hate black people.    \n",
       "4     derog_neg_emote_h        5                     I hate disabled people.    \n",
       "...                 ...      ...                                          ...   \n",
       "3723       spell_leet_h     3897       Gay people are all complete 4ssholes.    \n",
       "3724       spell_leet_h     3898     Black people are all complete 4ssholes.    \n",
       "3725       spell_leet_h     3899  Disabled people are all complete 4ssholes.    \n",
       "3726       spell_leet_h     3900          Muslims are all complete 4ssholes.    \n",
       "3727       spell_leet_h     3901       Immigrants are all complete 4ssholes.    \n",
       "\n",
       "     label_gold     target_ident direction focus_words focus_lemma  \\\n",
       "0       hateful            women   general        hate        hate   \n",
       "1       hateful     trans people   general        hate        hate   \n",
       "2       hateful       gay people   general        hate        hate   \n",
       "3       hateful     black people   general        hate        hate   \n",
       "4       hateful  disabled people   general        hate        hate   \n",
       "...         ...              ...       ...         ...         ...   \n",
       "3723    hateful       gay people   general    4ssholes     asshole   \n",
       "3724    hateful     black people   general    4ssholes     asshole   \n",
       "3725    hateful  disabled people   general    4ssholes     asshole   \n",
       "3726    hateful          Muslims   general    4ssholes     asshole   \n",
       "3727    hateful       immigrants   general    4ssholes     asshole   \n",
       "\n",
       "      ref_case_id  ref_templ_id  templ_id  \\\n",
       "0             NaN           NaN         1   \n",
       "1             NaN           NaN         1   \n",
       "2             NaN           NaN         1   \n",
       "3             NaN           NaN         1   \n",
       "4             NaN           NaN         1   \n",
       "...           ...           ...       ...   \n",
       "3723       1232.0         256.0       866   \n",
       "3724       1233.0         256.0       866   \n",
       "3725       1234.0         256.0       866   \n",
       "3726       1235.0         256.0       866   \n",
       "3727       1236.0         256.0       866   \n",
       "\n",
       "                                   case_templ  \n",
       "0                        I hate [IDENTITY_P].  \n",
       "1                        I hate [IDENTITY_P].  \n",
       "2                        I hate [IDENTITY_P].  \n",
       "3                        I hate [IDENTITY_P].  \n",
       "4                        I hate [IDENTITY_P].  \n",
       "...                                       ...  \n",
       "3723  [IDENTITY_P] are all complete 4ssholes.  \n",
       "3724  [IDENTITY_P] are all complete 4ssholes.  \n",
       "3725  [IDENTITY_P] are all complete 4ssholes.  \n",
       "3726  [IDENTITY_P] are all complete 4ssholes.  \n",
       "3727  [IDENTITY_P] are all complete 4ssholes.  \n",
       "\n",
       "[3728 rows x 12 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# load test suite\n",
    "hatecheck_df = pd.read_csv('./Data/Test Suite/hatecheck_final_ACL.csv', index_col=0)\n",
    "hatecheck_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert Test Suite to PyTorch Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# separate test suite into text and label series series\n",
    "hatecheck_texts = hatecheck_df.test_case.astype(\"string\").tolist()\n",
    "hatecheck_labels = hatecheck_df.label_gold.replace({'hateful': 1, 'non-hateful': 0}).tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import tokenizer (which includes custom special tokens for URLs, mentions and emojis and is the same across training datasets)\n",
    "tokenizer = BertTokenizerFast.from_pretrained(\"./Models/BERT_davidson2017_weighted/Final\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tokenize text series\n",
    "hatecheck_encodings = tokenizer(hatecheck_texts, truncation=True, padding=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write test suite as PyTorch dataset\n",
    "class HateDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, encodings, labels):\n",
    "        self.encodings = encodings\n",
    "        self.labels = labels\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n",
    "        item['labels'] = torch.tensor(self.labels[idx])\n",
    "        return item\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.labels)\n",
    "\n",
    "hatecheck_dataset = HateDataset(hatecheck_encodings, hatecheck_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Fine-Tuned BERT Models\n",
    "Weighted binary models for Davidson and Founta datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load fine-tuned models\n",
    "models = {}\n",
    "\n",
    "for dataset in ['davidson2017', 'founta2018']:\n",
    "    models['{}_weighted'.format(dataset)] = BertForSequenceClassification.from_pretrained(\"./Models/BERT_{}_weighted/Final\".format(dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate trainer objects for each model (already fine-tuned so no longer necessary to specify training and eval data)\n",
    "# output directory is redundant because there is no further training but needs to be specified anyway\n",
    "\n",
    "trainer = {}\n",
    "\n",
    "for model in models:\n",
    "    trainer[model] = Trainer(\n",
    "        model=models[model],         \n",
    "        args=TrainingArguments(\n",
    "            output_dir='./Models/BERT_{}/Test'.format(model),\n",
    "            per_device_eval_batch_size = 64)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Apply BERT Models to Test Suite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# apply models to test suite\n",
    "results = {}\n",
    "\n",
    "for model in trainer:\n",
    "    results[model] = trainer[model].predict(hatecheck_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Merge BERT Results with Test Suite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write predictions to series\n",
    "label_pred={}\n",
    "\n",
    "for model in trainer:\n",
    "\n",
    "    preds=[]\n",
    "    \n",
    "    for row in results[model][0]:\n",
    "        preds.append(int(np.argmax(row)))\n",
    "    \n",
    "    label_pred[model] = pd.Series(preds)\n",
    "    label_pred[model].replace({1: 'hateful', 0: 'non-hateful'}, inplace=True)\n",
    "    label_pred[model].name = 'pred_BERT_{}'.format(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# concatenate with test suite case ID column for export\n",
    "\n",
    "export_df = hatecheck_df['case_id']\n",
    "\n",
    "for model in trainer:\n",
    "    export_df = pd.concat([export_df, label_pred[model]], axis= 1)\n",
    "    \n",
    "export_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "export_df.to_pickle('./Data/Test Suite/results_BERT_weighted.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
